#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author  : junpeng_chen
# @Time    : 2023/7/26 10:39
# @File    : main
# @annotation    : 值识别任务主函数，conda环境为163:2041 - smartbi_merge
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = "6"
from flask import Flask, request, jsonify
from flask_cors import CORS, cross_origin
import json
import requests

from modules.cond_conn_op_recognition import CondConnOpRecognition
from modules.group_by_recognition import GroupByRecognition
from modules.limit_recognition import LimitRecognition
from modules.literal_recognition import LiteralRecognition
from modules.mdx_recognition import MDXRecognition
from modules.measure_recognition import MeasureRecognition
from modules.measure_type_recognition import MeasureTypeRecognition
from modules.order_by_recognition import OrderByRecognition
from modules.others_recognition import OthersRecognition
from modules.select_package import Sel_Package
from modules.time_recognition import TimeRecognition
from modules.agg_recognition import AggRecognition
from modules.header_recognition import HeaderRecognition
from modules.norm_recognition import NormRecognition
from modules.col_and_row_regulator import ColAndRowRegulator
from utils.log import logger
from utils.chatglm_utils import ChatGLMUtils
from utils.timer import Timer



class ValRegMain:
    """Holds the model and all recognition components of the value-recognition task.

    Construction only stores the model and the data paths and sets every
    component attribute to ``None``; the components themselves are built by
    :meth:`configure_components`.
    """

    def __init__(self, model=None, data_dict_path=None, headers_path=None):
        """Store the model and data paths and reset every component slot.

        :param model: model object handed to the model-based components; must not be None
        :param data_dict_path: path passed to the components that read the data dictionary
        :param headers_path: path passed to the header and normalization components
        :raises ValueError: when ``model`` is None (logged before being re-raised)
        """
        try:
            self.headers_path = headers_path
            self.data_dict_path = data_dict_path
            self.model = model
            if self.model is None:
                print("模型赋值失败")
                # ValueError is still an Exception, so existing callers that
                # catch Exception keep working, and the log line below now
                # carries a message instead of an empty string.
                raise ValueError("[ValRegMain]model must not be None")
            # Every component starts as None so that reading an attribute before
            # configure_components() yields None instead of AttributeError.
            self.header_reg = None
            self.norm_reg = None
            self.sel_reg = None
            self.data_dict = None
            self.literal_reg = None
            self.measure_type_reg = None
            self.limit_reg = None
            self.time_reg = None
            self.measure_reg = None
            self.other_reg = None
            self.agg_reg = None
            self.mdx_reg = None
            self.group_by_reg = None
            self.order_by_reg = None
            self.cond_conn_op_reg = None
            self.col_and_row_regulator = None
        except Exception as err:
            logger.error("[ValRegMain]Model initialization err[{}]: {}".format(type(err), str(err)))
            raise

    def configure_components(self):
        """Instantiate the component of every sub-task.

        :raises Exception: any error raised while building a component is
            logged and re-raised unchanged
        """
        try:
            # Header selection
            self.header_reg = HeaderRecognition(self.model, self.headers_path)

            # Query normalization
            self.norm_reg = NormRecognition(self.model, self.headers_path)

            # Select-related component
            self.sel_reg = Sel_Package()

            # Convert the header list of the original data dictionary into a dict
            # so that the attributes of a header can be looked up directly.
            self.data_dict = self.sel_reg.list_to_dict()

            # Literal recognition
            self.literal_reg = LiteralRecognition(self.model)

            # Measure recognition
            self.measure_type_reg = MeasureTypeRecognition(self.data_dict_path)

            # Limit recognition
            self.limit_reg = LimitRecognition()

            # Time-dimension condition format conversion
            self.time_reg = TimeRecognition(self.model)

            # Numeric condition recognition and format conversion
            self.measure_reg = MeasureRecognition(self.model)

            # Recognition and format conversion of the other literal conditions
            self.other_reg = OthersRecognition(self.data_dict_path)

            # agg recognition
            self.agg_reg = AggRecognition(self.model, self.data_dict_path)

            # mdx recognition
            self.mdx_reg = MDXRecognition(self.model)

            # group_by recognition
            self.group_by_reg = GroupByRecognition()

            # order_by recognition
            self.order_by_reg = OrderByRecognition()

            # cond_conn_op recognition
            self.cond_conn_op_reg = CondConnOpRecognition()

            # Recognition of the col and row fields
            self.col_and_row_regulator = ColAndRowRegulator()
        except Exception as err:
            logger.error("[ValRegMain]Components initialization err[{}]: {}".format(type(err), str(err)))
            raise

# Absolute path of the directory that contains this file; the data files
# handed to ValRegMain are located relative to it.
current_directory = os.path.dirname(os.path.abspath(__file__))

# Build the model wrapper and load the checkpoint once, at import time.
model = ChatGLMUtils()
model.load_model(
    local_loading=True,
    model_path="/home/cike/ytc/GLM2/checkpoints/Lora7_25_combine/checkpoint-64000/pytorch_model.bin",
    cuda_index=7,
)

# Abort the import when the model object is falsy.
if not model:
    print("模型加载失败")
    raise Exception

data_dict_file = current_directory + "/data/data_dict.json"
headers_file = current_directory + "/data/new_header.json"
val_reg_main = ValRegMain(model=model, data_dict_path=data_dict_file, headers_path=headers_file)
val_reg_main.configure_components()

def main_function(query: str) -> tuple:
    """
    Main entry point of the value-recognition task (integration pipeline).

    The query is normalized, its headers are selected, and then every
    sub-task component of the module-level ``val_reg_main`` is run in turn.
    Each step is bracketed by ``timer.start()`` / ``timer.stop(label)``; a
    failing step is logged and its exception re-raised unchanged.

    :param query: natural-language query
    :return: a 2-tuple ``(response, temp_response)``: ``response`` is the JSON
        payload that is sent on to the database service, ``temp_response``
        collects the raw ("origin") outputs of the individual steps plus the
        selected headers and the normalized query
    """
    timer = Timer()
    global_timer = Timer()
    global_timer.start()
    conds = []
    having = []

    # Normalize the query.
    timer.start()
    query = val_reg_main.norm_reg.recognize(query)
    timer.stop('norm')

    # Get the matching column names.
    timer.start()
    column = val_reg_main.header_reg.recognize(query)
    timer.stop('header')

    # Build the sel field.
    timer.start()
    try:
        sel = val_reg_main.sel_reg.select_package(query, column, val_reg_main.data_dict)["sel"]
    except Exception as err:
        logger.error("[Main]select error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('select')

    # Literal and operator recognition.
    timer.start()
    try:
        literals_meta = val_reg_main.literal_reg.recognize(query, sel)
        literals = literals_meta['processed']['literals']
        literals_origin = literals_meta['origin']
    except Exception as err:
        logger.error("[Main]literals&opts error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('literals&opts')

    # Measure recognition.
    timer.start()
    try:
        measure = val_reg_main.measure_type_reg.measure_type(sel)['measure']
    except Exception as err:
        logger.error("[Main]measure error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('measure')

    # Limit recognition.
    timer.start()
    try:
        limit = val_reg_main.limit_reg.limit(query)['limit']
    except Exception as err:
        logger.error("[Main]limit error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('limit')

    # Time-dimension condition format conversion.
    timer.start()
    try:
        # Filter the time conditions out of the formatted literals of the literal step.
        time_literals = val_reg_main.time_reg.filter(literals)
        time_conds_meta = val_reg_main.time_reg.transform(query, time_literals)
        time_conds = time_conds_meta['processed']
        time_conds_origin = time_conds_meta['origin']
        # conds is the final conds field of the SQL-like output.
        conds.extend(time_conds)
    except Exception as err:
        logger.error("[Main]time condition transformation error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('time condition transformation')

    # Conversion of the other literal types.
    timer.start()
    try:
        other_literals = val_reg_main.other_reg.filter(literals)
        other_conds = val_reg_main.other_reg.transform(query, other_literals)
        conds.extend(other_conds)
    except Exception as err:
        logger.error("[Main]other condition transformation error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('other condition transformation')

    # agg
    timer.start()
    try:
        agg_meta = val_reg_main.agg_reg.recognition(query, sel, measure)
        agg = agg_meta['processed']['agg']
        agg_origin = agg_meta['origin']
    except Exception as err:
        logger.error("[Main]agg error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('agg_reg')

    # Numeric condition format conversion.
    timer.start()
    try:
        # Filter the numeric conditions out of the formatted literals.
        measure_literals, measure_agg = val_reg_main.measure_reg.filter(query, literals, agg)
        measure_meta = val_reg_main.measure_reg.transform(query, measure_literals, measure, measure_agg)
        measure_conds, measure_having = measure_meta['processed']
        measure_origin = measure_meta['origin']
        # conds / having are the final conds / having fields of the SQL-like output.
        conds.extend(measure_conds)
        having.extend(measure_having)
    except Exception as err:
        logger.error("[Main]measure condition transformation error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('measure condition transformation')

    # mdx
    timer.start()
    try:
        mdx_meta = val_reg_main.mdx_reg.recognition(query, sel, measure)
        # The processed value is not put into the response, but the lookup is
        # kept so that a malformed result still fails here.
        mdx = mdx_meta['processed']['mdx']
        mdx_origin = mdx_meta['origin']
    except Exception as err:
        # Previously logged as "agg error" (copy-paste slip).
        logger.error("[Main]mdx error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('mdx_reg')

    # group_by
    timer.start()
    try:
        group_by = val_reg_main.group_by_reg.group_by(query, sel, agg, measure)
    except Exception as err:
        logger.error("[Main]group by error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('group by')

    # order_by
    # This start() was missing, so stop('order by') had no matching start
    # unlike every other timed step.
    timer.start()
    try:
        order_by = val_reg_main.order_by_reg.order_by(query, sel, having, measure)
    except Exception as err:
        logger.error("[Main]order by error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('order by')

    # cond_conn_op
    timer.start()
    try:
        cond_conn_op = val_reg_main.cond_conn_op_reg.conn_op_recognition(conds)
    except Exception as err:
        logger.error("[Main]cond_conn_op error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('cond_conn_op')

    # col and row (not timed)
    try:
        col, row = val_reg_main.col_and_row_regulator.col_and_row(sel, measure)
    except Exception as err:
        logger.error("[Main]col and row error[{}]: {}".format(type(err), str(err)))
        raise

    # JSON that is sent to the database. Key order is kept as before because
    # the dict is serialized with json.dumps by the caller.
    response = {
        'sel': sel,
        'agg': agg,
        'conds': conds,
        'group_by': group_by,
        'order_by': order_by['order_by'],  # temporary unwrapping for now
        'limit': limit,
        'having': having,
        'cond_conn_op': cond_conn_op['cond_conn_op'],  # temporary unwrapping for now
        'measure': measure,
        # For now has_time, forceType, type, rowNotEmpty and data_source have
        # to be attached as well; they keep default values for the time being.
        'row': row,  # for now the dimensions of sel are put into row
        'col': col,  # col is not handled yet
        'has_time': 'true',
        'forceType': '',
        'type': 'TABLE_CROSS',
        'rowNotEmpty': 'true',
        'data_source': 'AUGMENTED_DATASET',
    }

    # Temporary JSON holding the raw step outputs.
    temp_response = {
        'agg': agg_origin,
        'literal_and_operator': literals_origin,
        'measure_conds': measure_origin,
        'time_conds': time_conds_origin,
        'mdx': mdx_origin,
        'header_reg': column,
        'normalization': query,
    }

    global_timer.stop('Total')
    return response, temp_response

# app = Flask(__name__)
# @app.route('/getSQLs', methods=["POST"])
# def getSQLs():
#     try:
#         header_data = request.get_json()
#         if "query" not in header_data or "column" not in header_data:
#             raise Exception
#         query = header_data['query']
#         column = header_data['column']
#     except Exception as err:
#         logger.error("parameter passing error[{}]: {}".format(type(err),str(err)))
#         raise
    
#     # 核心识别函数
#     sqls = main_function(query, column)
#     return jsonify(sqls)

#理论上新的系统不需要再给Header了
# @app.route('/getSQLs', methods=["POST"])
# def getSQLs():
#     try:
#         header_data = request.get_json()
#         print(header_data)
#         if "query" not in header_data:
#             raise Exception
#         query = header_data['query']
#     except Exception as err:
#         logger.error("parameter passing error[{}]: {}".format(type(err),str(err)))
#         raise   
#     # 核心识别函数
#     sqls = main_function(query)
#     return jsonify(sqls)

def Post(url, data):
    """POST ``data`` to ``url`` and return the response body parsed as JSON.

    A success or failure line is printed depending on whether the HTTP status
    code is 200; the body is parsed with ``json.loads`` in both cases.

    :param url: target URL
    :param data: payload passed to ``requests.post`` as ``data``
    :return: the decoded JSON document
    """
    reply = requests.post(url, data=data)
    status = reply.status_code
    if status != 200:
        print('请求失败，状态码：', status)
    else:
        print('请求成功！')
    return json.loads(reply.text)

# Login token of the Smartbi service, cached across requests once obtained.
const_token = None
app = Flask(__name__)
cors = CORS(app)

# Smartbi endpoints and fixed request parameters shared by both routes.
# NOTE(review): the credentials and the table id are hard-coded in source;
# consider moving them to configuration.
_URL_LOGIN = 'http://proj.smartbi.com.cn:9070/aiweb/api/v1/login'
_URL_SQL = 'http://proj.smartbi.com.cn:9070/aiweb/integration/api/v1/query_with_nl2sql'
_NL2SQL_TABLE = 'I8a8ae5ca0178549554951b9501785cefe3f00058'
_LOGIN_FORM = {
    'userName': 'huagong1',
    'password': 'huagong1'
}
# Response code on which a fresh login is performed and the query repeated
# (presumably an expired or invalid token -- confirm against the Smartbi API).
_RELOGIN_CODE = -2


def _refresh_token():
    """Log in to the Smartbi service and cache the returned token in ``const_token``."""
    global const_token
    token_data = Post(_URL_LOGIN, dict(_LOGIN_FORM))
    const_token = token_data['token']


def _query_nl2sql(str_sql):
    """Send one nl2sql query with the cached token; return ``(code, answer)``.

    ``answer`` is the reply's ``result`` entry when it is present and not None
    (further narrowed to its ``html`` entry when there is one), otherwise the
    whole reply.
    """
    reply = Post(_URL_SQL, {
        'table': _NL2SQL_TABLE,
        'token': const_token,
        'nl2sql': str_sql
    })
    code = reply['code']
    answer = reply
    if 'result' in reply and reply['result'] is not None:
        answer = reply['result']
        if 'html' in answer:
            answer = answer['html']
    return code, answer


def _run_nl2sql(str_sql):
    """Run an nl2sql query, logging in first when needed and retrying once after a re-login.

    :param str_sql: JSON string placed in the ``nl2sql`` field of the request
    :return: the unwrapped answer of the last query that was sent
    """
    if not const_token:
        _refresh_token()
    code, answer = _query_nl2sql(str_sql)
    if code == _RELOGIN_CODE:
        _refresh_token()
        code, answer = _query_nl2sql(str_sql)
    return answer


@app.route('/getSQLs', methods=["POST"])
def getSQLs():
    """Recognize the question in the request and run the resulting query.

    The request body must be JSON of the form ``{"data": {"question": ...}}``.
    The question is passed to ``main_function``; the produced JSON is
    forwarded to the nl2sql endpoint.

    :return: JSON with ``data`` (the generated query JSON), ``answer`` (the
        service answer), ``model_output`` (raw step outputs) and ``status``
    :raises Exception: parameter errors are logged and re-raised
    """
    try:
        header_data = request.get_json()['data']
        print(header_data)
        if "question" not in header_data:
            # A message is supplied so that the log line below is not empty.
            raise ValueError('request "data" has no "question" field')
        query = header_data['question']
    except Exception as err:
        logger.error("parameter passing error[{}]: {}".format(type(err),str(err)))
        raise
    # Core recognition function
    sqls, origin = main_function(query)
    result = _run_nl2sql(json.dumps(sqls))

    final_result = {
        'data': [sqls],
        'answer': [result],
        'model_output': [origin],
        'status': 0
    }
    return jsonify(final_result)


@app.route('/resendJSON', methods=["POST"])
def resendJSON():
    """Forward an already built query JSON to the nl2sql endpoint.

    The request body must be JSON of the form ``{"data": ...}``; the ``data``
    value is serialized and sent as the ``nl2sql`` field.

    :return: JSON with a single ``answer`` list holding the service answer
    :raises Exception: errors while reading the request are logged and re-raised
    """
    try:
        data = request.get_json()['data']
        str_sql = json.dumps(data)
    except Exception as err:
        logger.error("json error[{}]: {}".format(type(err),str(err)))
        raise
    result = _run_nl2sql(str_sql)
    final_result = {
        'answer': [result],
    }
    return jsonify(final_result)

if __name__ == '__main__':
    # host is the address the backend binds to. 0.0.0.0 makes the interface
    # reachable through any address of the server, so the front end only has
    # to call the server address; a fixed address such as 222.222.222.222
    # would restrict access to that address. port is the port of the interface.
    # debug=True would enable debug mode, in which the service does not have
    # to be restarted after every code change:
    # app.run(host='0.0.0.0', port=5000, debug=True)
    # The production deployment runs on port 5003.
    serve_host, serve_port = '0.0.0.0', 5003
    app.run(host=serve_host, port=serve_port, debug=False)

    # The first query takes roughly 2s longer.
    # query = "东南地区上个月的MQL个数是多少?"
    # response = main_function(query)
    # print(response)
    # response = main_function(query)
    # print(response)


